{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "730d7c26",
   "metadata": {},
   "outputs": [],
   "source": [
    "%run -i \"../util/util_simple_classifier.ipynb\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e0ef169e",
   "metadata": {},
   "source": [
    "# Encode sentences using word2vec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e54cc0eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gensim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9231c5ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = gensim.models.KeyedVectors.load_word2vec_format('../data/GoogleNews-vectors-negative300.bin.gz', binary=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "27b0a447",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 1.25976562e-01  2.97851562e-02  8.60595703e-03  1.39648438e-01\n",
      " -2.56347656e-02 -3.61328125e-02  1.11816406e-01 -1.98242188e-01\n",
      "  5.12695312e-02  3.63281250e-01 -2.42187500e-01 -3.02734375e-01\n",
      " -1.77734375e-01 -2.49023438e-02 -1.67968750e-01 -1.69921875e-01\n",
      "  3.46679688e-02  5.21850586e-03  4.63867188e-02  1.28906250e-01\n",
      "  1.36718750e-01  1.12792969e-01  5.95703125e-02  1.36718750e-01\n",
      "  1.01074219e-01 -1.76757812e-01 -2.51953125e-01  5.98144531e-02\n",
      "  3.41796875e-01 -3.11279297e-02  1.04492188e-01  6.17675781e-02\n",
      "  1.24511719e-01  4.00390625e-01 -3.22265625e-01  8.39843750e-02\n",
      "  3.90625000e-02  5.85937500e-03  7.03125000e-02  1.72851562e-01\n",
      "  1.38671875e-01 -2.31445312e-01  2.83203125e-01  1.42578125e-01\n",
      "  3.41796875e-01 -2.39257812e-02 -1.09863281e-01  3.32031250e-02\n",
      " -5.46875000e-02  1.53198242e-02 -1.62109375e-01  1.58203125e-01\n",
      " -2.59765625e-01  2.01416016e-02 -1.63085938e-01  1.35803223e-03\n",
      " -1.44531250e-01 -5.68847656e-02  4.29687500e-02 -2.46582031e-02\n",
      "  1.85546875e-01  4.47265625e-01  9.58251953e-03  1.31835938e-01\n",
      "  9.86328125e-02 -1.85546875e-01 -1.00097656e-01 -1.33789062e-01\n",
      " -1.25000000e-01  2.83203125e-01  1.23046875e-01  5.32226562e-02\n",
      " -1.77734375e-01  8.59375000e-02 -2.18505859e-02  2.05078125e-02\n",
      " -1.39648438e-01  2.51464844e-02  1.38671875e-01 -1.05468750e-01\n",
      "  1.38671875e-01  8.88671875e-02 -7.51953125e-02 -2.13623047e-02\n",
      "  1.72851562e-01  4.63867188e-02 -2.65625000e-01  8.91113281e-03\n",
      "  1.49414062e-01  3.78417969e-02  2.38281250e-01 -1.24511719e-01\n",
      " -2.17773438e-01 -1.81640625e-01  2.97851562e-02  5.71289062e-02\n",
      " -2.89306641e-02  1.24511719e-02  9.66796875e-02 -2.31445312e-01\n",
      "  5.81054688e-02  6.68945312e-02  7.08007812e-02 -3.08593750e-01\n",
      " -2.14843750e-01  1.45507812e-01 -4.27734375e-01 -9.39941406e-03\n",
      "  1.54296875e-01 -7.66601562e-02  2.89062500e-01  2.77343750e-01\n",
      " -4.86373901e-04 -1.36718750e-01  3.24218750e-01 -2.46093750e-01\n",
      " -3.03649902e-03 -2.11914062e-01  1.25000000e-01  2.69531250e-01\n",
      "  2.04101562e-01  8.25195312e-02 -2.01171875e-01 -1.60156250e-01\n",
      " -3.78417969e-02 -1.20117188e-01  1.15234375e-01 -4.10156250e-02\n",
      " -3.95507812e-02 -8.98437500e-02  6.34765625e-03  2.03125000e-01\n",
      "  1.86523438e-01  2.73437500e-01  6.29882812e-02  1.41601562e-01\n",
      " -9.81445312e-02  1.38671875e-01  1.82617188e-01  1.73828125e-01\n",
      "  1.73828125e-01 -2.37304688e-01  1.78710938e-01  6.34765625e-02\n",
      "  2.36328125e-01 -2.08984375e-01  8.74023438e-02 -1.66015625e-01\n",
      " -7.91015625e-02  2.43164062e-01 -8.88671875e-02  1.26953125e-01\n",
      " -2.16796875e-01 -1.73828125e-01 -3.59375000e-01 -8.25195312e-02\n",
      " -6.49414062e-02  5.07812500e-02  1.35742188e-01 -7.47070312e-02\n",
      " -1.64062500e-01  1.15356445e-02  4.45312500e-01 -2.15820312e-01\n",
      " -1.11328125e-01 -1.92382812e-01  1.70898438e-01 -1.25000000e-01\n",
      "  2.65502930e-03  1.92382812e-01 -1.74804688e-01  1.39648438e-01\n",
      "  2.92968750e-01  1.13281250e-01  5.95703125e-02 -6.39648438e-02\n",
      "  9.96093750e-02 -2.72216797e-02  1.96533203e-02  4.27246094e-02\n",
      " -2.46093750e-01  6.39648438e-02 -2.25585938e-01 -1.68945312e-01\n",
      "  2.89916992e-03  8.20312500e-02  3.41796875e-01  4.32128906e-02\n",
      "  1.32812500e-01  1.42578125e-01  7.61718750e-02  5.98144531e-02\n",
      " -1.19140625e-01  2.74658203e-03 -6.29882812e-02 -2.72216797e-02\n",
      " -4.82177734e-03 -8.20312500e-02 -2.49023438e-02 -4.00390625e-01\n",
      " -1.06933594e-01  4.24804688e-02  7.76367188e-02 -1.16699219e-01\n",
      "  7.37304688e-02 -9.22851562e-02  1.07910156e-01  1.58203125e-01\n",
      "  4.24804688e-02  1.26953125e-01  3.61328125e-02  2.67578125e-01\n",
      " -1.01074219e-01 -3.02734375e-01 -5.76171875e-02  5.05371094e-02\n",
      "  5.26428223e-04 -2.07031250e-01 -1.38671875e-01 -8.97216797e-03\n",
      " -2.78320312e-02 -1.41601562e-01  2.07031250e-01 -1.58203125e-01\n",
      "  1.27929688e-01  1.49414062e-01 -2.24609375e-02 -8.44726562e-02\n",
      "  1.22558594e-01  2.15820312e-01 -2.13867188e-01 -3.12500000e-01\n",
      " -3.73046875e-01  4.08935547e-03  1.07421875e-01  1.06933594e-01\n",
      "  7.32421875e-02  8.97216797e-03 -3.88183594e-02 -1.29882812e-01\n",
      "  1.49414062e-01 -2.14843750e-01 -1.83868408e-03  9.91210938e-02\n",
      "  1.57226562e-01 -1.14257812e-01 -2.05078125e-01  9.91210938e-02\n",
      "  3.69140625e-01 -1.97265625e-01  3.54003906e-02  1.09375000e-01\n",
      "  1.31835938e-01  1.66992188e-01  2.35351562e-01  1.04980469e-01\n",
      " -4.96093750e-01 -1.64062500e-01 -1.56250000e-01 -5.22460938e-02\n",
      "  1.03027344e-01  2.43164062e-01 -1.88476562e-01  5.07812500e-02\n",
      " -9.37500000e-02 -6.68945312e-02  2.27050781e-02  7.61718750e-02\n",
      "  2.89062500e-01  3.10546875e-01 -5.37109375e-02  2.28515625e-01\n",
      "  2.51464844e-02  6.78710938e-02 -1.21093750e-01 -2.15820312e-01\n",
      " -2.73437500e-01 -3.07617188e-02 -3.37890625e-01  1.53320312e-01\n",
      "  2.33398438e-01 -2.08007812e-01  3.73046875e-01  8.20312500e-02\n",
      "  2.51953125e-01 -7.61718750e-02 -4.66308594e-02 -2.23388672e-02\n",
      "  2.99072266e-02 -5.93261719e-02 -4.66918945e-03 -2.44140625e-01\n",
      " -2.09960938e-01 -2.87109375e-01 -4.54101562e-02 -1.77734375e-01\n",
      " -2.79296875e-01 -8.59375000e-02  9.13085938e-02  2.51953125e-01]\n"
     ]
    }
   ],
   "source": [
    "vec_king = model['king']\n",
    "print(vec_king)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a90c42d1",
   "metadata": {},
   "source": [
    "# Word similarity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "0f6bf30b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[('apples', 0.720359742641449), ('pear', 0.6450697183609009), ('fruit', 0.6410146355628967), ('berry', 0.6302295327186584), ('pears', 0.613396167755127), ('strawberry', 0.6058260798454285), ('peach', 0.6025872826576233), ('potato', 0.5960935354232788), ('grape', 0.5935863852500916), ('blueberry', 0.5866668224334717), ('cherries', 0.5784382820129395), ('mango', 0.5751855969429016), ('apricot', 0.5727777481079102), ('melon', 0.5719985365867615), ('almond', 0.5704829692840576)]\n",
      "[('tomatoes', 0.8442263007164001), ('lettuce', 0.7069936990737915), ('asparagus', 0.7050934433937073), ('peaches', 0.6938520669937134), ('cherry_tomatoes', 0.6897529363632202), ('strawberry', 0.6888598799705505), ('strawberries', 0.6832595467567444), ('bell_peppers', 0.6813562512397766), ('potato', 0.6784172058105469), ('cantaloupe', 0.6780219078063965), ('celery', 0.675195574760437), ('onion', 0.6740139722824097), ('cucumbers', 0.6706333160400391), ('spinach', 0.6682621240615845), ('cauliflower', 0.6681587100028992)]\n"
     ]
    }
   ],
   "source": [
    "print(model.most_similar(['apple'], topn=15))\n",
    "print(model.most_similar(['tomato'], topn=15))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "00901987",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_word_vectors(sentence, model):\n",
    "    word_vectors = []\n",
    "    for word in sentence:\n",
    "        try:\n",
    "            word_vector = model[word.lower()]\n",
    "            word_vectors.append(word_vector)\n",
    "        except KeyError:\n",
    "            continue\n",
    "    return word_vectors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "0ff4de77",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_sentence_vector(word_vectors):\n",
    "    matrix = np.array(word_vectors)\n",
    "    centroid = np.mean(matrix[:,:], axis=0)\n",
    "    return centroid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "3e42a660",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found cached dataset rotten_tomatoes (/home/zhenya/.cache/huggingface/datasets/rotten_tomatoes/default/1.0.0/40d411e45a6ce3484deed7cc15b82a53dad9a72aafd9f86f8f227134bec5ca46)\n",
      "Found cached dataset rotten_tomatoes (/home/zhenya/.cache/huggingface/datasets/rotten_tomatoes/default/1.0.0/40d411e45a6ce3484deed7cc15b82a53dad9a72aafd9f86f8f227134bec5ca46)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0       0.54      0.57      0.55       160\n",
      "           1       0.54      0.51      0.53       160\n",
      "\n",
      "    accuracy                           0.54       320\n",
      "   macro avg       0.54      0.54      0.54       320\n",
      "weighted avg       0.54      0.54      0.54       320\n",
      "\n"
     ]
    }
   ],
   "source": [
    "vectorize = lambda x: get_sentence_vector(get_word_vectors(x, model))\n",
    "(train_df, test_df) = load_train_test_dataset_pd()\n",
    "(X_train, X_test, y_train, y_test) = create_train_test_data(train_df, test_df, vectorize)\n",
    "clf = train_classifier(X_train, y_train)\n",
    "test_classifier(test_df, clf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "141c57de",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "computer\n"
     ]
    }
   ],
   "source": [
    "words = ['banana', 'apple', 'computer', 'strawberry']\n",
    "print(model.doesnt_match(words))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "3a7df40d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "glass\n"
     ]
    }
   ],
   "source": [
    "word = \"cup\"\n",
    "words = ['glass', 'computer', 'pencil', 'watch']\n",
    "print(model.most_similar_to_given(word, words))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04bebcd7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
